
[Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization#31179

Merged
yewentao256 merged 1 commit into vllm-project:main from c0de128:fix/rocm-fp8-silu-mul-dtype
Dec 24, 2025

Conversation

@c0de128 (Contributor) commented Dec 22, 2025

Summary

Fix the hardcoded torch.float8_e4m3fn dtype in silu_mul_per_token_group_quant_fp8_colmajor(), which caused an incorrect output dtype and accuracy issues on ROCm platforms that use torch.float8_e4m3fnuz.

Problem

The function in vllm/model_executor/layers/quantization/utils/fp8_utils.py was:

  1. Hardcoding torch.float8_e4m3fn dtype for the output tensor (line 629)
  2. Using default finfo.min/max values from torch.float8_e4m3fn (lines 640-642)

On ROCm platforms that use torch.float8_e4m3fnuz:

  • The output tensor is created with the wrong dtype
  • The clamp bounds are wrong for fnuz: torch reports a finfo max of 240.0, but the correct clamp value for fnuz on ROCm is 224.0

Solution

Apply the same pattern already used in per_token_group_quant_fp8() in the same file (lines 766-770):

fp8_dtype = current_platform.fp8_dtype()
finfo = torch.finfo(fp8_dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
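As a standalone illustration (not the vLLM code itself), the dtype-dependent range selection can be sketched in plain Python; the finfo bounds below are the values torch reports for each float8 variant, and fnuz platforms override them with the tighter ±224.0:

```python
# Illustrative sketch only: finfo bounds hardcoded from torch's
# float8_e4m3fn (+-448.0) and float8_e4m3fnuz (+-240.0) types.
FP8_FINFO = {
    "float8_e4m3fn": (-448.0, 448.0),
    "float8_e4m3fnuz": (-240.0, 240.0),
}

def fp8_clamp_range(dtype_name: str, is_fnuz: bool) -> tuple[float, float]:
    """Return the (min, max) clamp bounds for a given fp8 dtype."""
    lo, hi = FP8_FINFO[dtype_name]
    if is_fnuz:
        # ROCm fnuz hardware needs the tighter +-224.0 bound.
        return -224.0, 224.0
    return lo, hi
```

This mirrors the branch structure of the snippet above: the finfo values are used only on non-fnuz platforms.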

Test Plan

This is a consistency fix that aligns the function with the existing ROCm-aware pattern used elsewhere in the same file. The fix ensures:

  • Correct fp8 dtype is used based on platform
  • Correct fp8 min/max values are used for fnuz dtype on ROCm
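To see why the clamp bound matters, here is a minimal, hypothetical per-group quantization sketch in plain Python (not vLLM's Triton kernel): values are scaled by amax / fp8_max and clamped, so passing an fp8_max larger than the hardware supports would produce out-of-range codes:

```python
def quantize_group(values: list[float], fp8_max: float) -> tuple[list[float], float]:
    """Scale a group of values into [-fp8_max, fp8_max]; return (quantized, scale)."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = amax / fp8_max if amax > 0.0 else 1.0
    quantized = [max(-fp8_max, min(fp8_max, v / scale)) for v in values]
    return quantized, scale

# With the fnuz-safe bound of 224.0, every quantized value fits the hardware range.
q, s = quantize_group([3.0, -7.5, 0.25], 224.0)
```

With fp8_max = 240.0 (or 448.0) on a fnuz platform, the largest-magnitude value in each group would be scaled to a code the hardware cannot represent, which is the accuracy issue this PR fixes.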

Related

This is similar to the pattern established for other FP8 quantization functions in this file that already handle ROCm fnuz correctly.

@gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in silu_mul_per_token_group_quant_fp8_colmajor where a hardcoded torch.float8_e4m3fn dtype and its associated min/max values were used, causing accuracy issues on ROCm platforms that expect torch.float8_e4m3fnuz. The changes correctly use current_platform.fp8_dtype() to determine the appropriate float8 data type and apply platform-specific min/max values for quantization, aligning the function's behavior with per_token_group_quant_fp8 and ensuring correctness on ROCm. The changes are correct and well-implemented.

@c0de128 c0de128 changed the title [Bugfix][ROCm] Use platform fp8_dtype in silu_mul_per_token_group_quant_fp8_colmajor [ROCm][Strix Halo] Fix FP8 dtype in silu_mul quantization Dec 22, 2025
@c0de128 (Contributor, Author) commented Dec 22, 2025

@hongxiayang @jithunnair-amd This is ready for review and addresses critical FP8 dtype handling for ROCm on the new Strix Halo architecture.

@c0de128 c0de128 changed the title [ROCm][Strix Halo] Fix FP8 dtype in silu_mul quantization [ROCm][Strix Halo] Fix for FP8 dtype in silu_mul quantization Dec 22, 2025
@yewentao256 (Member) left a comment

LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 23, 2025
@c0de128 (Contributor, Author) commented Dec 24, 2025

Hi maintainers,

This PR has been approved by @yewentao256. The CI failures appear to be known flaky tests (e.g., async-engine-inputs-utils-worker-config-test timeout) that are unrelated to this FP8 dtype fix.

Would it be possible to trigger a merge or a final re-run of the failing jobs? The fix itself is straightforward: using current_platform.fp8_dtype() instead of the hardcoded torch.float8_e4m3fn.

Thank you!

@c0de128 (Contributor, Author) commented Dec 24, 2025

@vllm-bot rerun ci

The async-engine-inputs-utils-worker-config-test failure appears unrelated to this PR (FP8 dtype fix). This change only affects ROCm FP8 quantization, not async engine configuration. Requesting CI rerun.

…nt_fp8_colmajor

The function was hardcoding torch.float8_e4m3fn dtype and using its
default min/max values. On ROCm platforms that use torch.float8_e4m3fnuz,
this causes incorrect dtype and accuracy issues.

This fix:
- Uses current_platform.fp8_dtype() instead of hardcoded dtype
- Applies the same ROCm-aware fp8 min/max logic (224.0 for fnuz) that
  is already used in per_token_group_quant_fp8() in the same file

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
@c0de128 c0de128 force-pushed the fix/rocm-fp8-silu-mul-dtype branch from 2f4658d to f3b8abe Compare December 24, 2025 13:27
@c0de128 c0de128 changed the title [ROCm][Strix Halo] Fix for FP8 dtype in silu_mul quantization [Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization Dec 24, 2025
@yewentao256 yewentao256 merged commit 66c9887 into vllm-project:main Dec 24, 2025
55 checks passed
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Dec 30, 2025
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…project#31179)

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
@c0de128 c0de128 deleted the fix/rocm-fp8-silu-mul-dtype branch January 27, 2026 17:56
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm
